{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1_score(precision, recall):\n",
    "    try:\n",
    "        return 2 * precision * recall / (precision + recall)\n",
    "    except:\n",
    "        return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precison = 0.5\n",
    "recall = 0.5\n",
    "f1_score(precison, recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.18000000000000002"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precison = 0.1\n",
    "recall = 0.9\n",
    "f1_score(precison, recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precison = 0.0\n",
    "recall = 1.0\n",
    "f1_score(precison, recall)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn import datasets\n",
    "\n",
    "digits = datasets.load_digits()\n",
    "X = digits.data\n",
    "y = digits.target.copy()\n",
    "\n",
    "y[digits.target==9] = 1\n",
    "y[digits.target!=9] = 0\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\yinahe\\AppData\\Local\\Continuum\\anaconda3\\lib\\site-packages\\sklearn\\linear_model\\logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
      "  FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9755555555555555"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "log_reg = LogisticRegression()\n",
    "log_reg.fit(X_train, y_train)\n",
    "log_reg.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict = log_reg.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[403,   2],\n",
       "       [  9,  36]], dtype=int64)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "confusion_matrix(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9473684210526315"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import precision_score\n",
    "\n",
    "precision_score(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import recall_score\n",
    "\n",
    "recall_score(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8674698795180723"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import f1_score\n",
    "\n",
    "f1_score(y_test, y_predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 10-5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-22.05700117, -33.02940957, -16.21334087, -80.3791447 ,\n",
       "       -48.25125396, -24.54005629, -44.39168773, -25.04292757,\n",
       "        -0.97829292, -19.7174399 , -66.25139191, -51.09600903,\n",
       "       -31.49348767, -46.05335761, -38.67875653, -29.80471251,\n",
       "       -37.58849546, -82.57569732, -37.81903096, -11.01165509,\n",
       "        -9.17439784, -85.13004331, -16.71617974, -46.23725224,\n",
       "        -5.32992784, -47.91762441, -11.66729524, -39.1960157 ,\n",
       "       -25.25293243, -14.3664722 , -16.99783066, -28.91904826,\n",
       "       -34.33940562, -29.47603768,  -7.85812845,  -3.82094912,\n",
       "       -24.08161558, -22.16362592, -33.61218699, -23.14023293,\n",
       "       -26.9180406 , -62.3893701 , -38.85690022, -66.77259733,\n",
       "       -20.14482853, -17.47886658, -18.06799819, -22.22224569,\n",
       "       -29.62302848, -19.73171824,   1.49552053,   8.32079827,\n",
       "       -36.29307324, -42.50732725, -25.90460192, -34.98959422,\n",
       "        -8.42010631, -50.04725431, -51.48208247,  19.88958588,\n",
       "        -8.91888462, -31.99343636, -11.66099193,  -0.47143293,\n",
       "       -49.16130314, -46.23803406, -25.05391924, -19.61348036,\n",
       "       -36.16658674,  -3.12536661,  -3.91419508, -19.06042532,\n",
       "       -21.03312808, -41.5224556 , -12.00623595, -33.89272748,\n",
       "       -35.84803017, -30.60476103, -56.51642519, -18.4546976 ,\n",
       "         4.51537007, -17.21606393, -76.65096457, -58.54520094,\n",
       "       -31.72091817, -29.90829618, -33.31897385,  -9.08750195,\n",
       "       -47.6445028 , -66.15300205, -16.95628635, -22.24904223,\n",
       "       -11.48959218, -18.10557462, -68.65400791, -47.02577635,\n",
       "       -40.11868277, -35.50211369, -17.19766251, -63.10281867,\n",
       "       -16.95444889, -55.1024096 , -28.71257167, -68.81580574,\n",
       "       -68.31018197,  -6.25934556, -25.84000062, -38.00879504,\n",
       "       -27.90914413, -15.44711517, -27.45896252, -19.59776623,\n",
       "        12.33461242, -23.03865477, -35.94461771, -30.02831815,\n",
       "       -70.06672491, -29.48728531, -52.98823525, -24.97015125,\n",
       "       -12.32842226, -48.00990099,  -2.49964896, -59.92450044,\n",
       "       -31.18113955,  -8.6572849 , -71.34894225, -57.01116921,\n",
       "       -21.09871005, -21.53852701, -69.3430687 , -18.63518493,\n",
       "       -39.91431222, -57.26579405,  -0.84508415, -21.88379715,\n",
       "       -22.64112566, -29.2126242 , -35.15697405, -20.25854026,\n",
       "       -11.40289626,   3.87277821,   6.09027067,   1.42894271,\n",
       "        -7.82708411, -39.35176617,  12.21054353, -75.1017521 ,\n",
       "       -75.38157931, -50.41808048, -11.55440106, -48.45868149,\n",
       "       -75.44074694, -29.9805692 , -64.11577067,  -7.16585109,\n",
       "        -6.52452088, -18.97255236, -33.71616873, -17.76220708,\n",
       "       -45.5938075 , -33.53733822, -34.08689163, -73.31507699,\n",
       "       -15.43458348,  12.16748889, -56.45925334,  -6.0319552 ,\n",
       "       -49.08442691, -16.54210767,  -2.05949413, -11.81040665,\n",
       "       -33.47402629, -50.7717709 , -10.62903502, -17.67500655,\n",
       "        -5.07826173, -25.25779086, -16.61516052,   3.91125798,\n",
       "       -46.75600564, -12.89881261, -25.74790992, -16.31799949,\n",
       "       -23.55099019, -83.48234577,  -6.23508453, -19.83969155,\n",
       "       -20.06233901, -26.65464787, -27.11272708, -39.6371533 ,\n",
       "       -39.81300074, -27.43662041, -24.11826958, -21.24519902,\n",
       "       -10.49822191, -19.39895263, -41.95759968, -43.6236127 ,\n",
       "       -16.06839181, -64.09610659, -24.75462044, -56.57386562,\n",
       "       -13.50011537, -30.01576521,   3.9371969 , -44.71703255,\n",
       "        -8.6936665 ,   1.58878063,  -2.7624751 , -11.91891397,\n",
       "         7.58788806,  -7.25885896, -46.73815547, -49.19661233,\n",
       "        -4.80424525, -19.61030726, -24.30539668, -48.98792204,\n",
       "       -14.98133788, -24.83600992, -16.93956321, -19.46783406,\n",
       "       -15.77206609, -17.0012166 , -39.23694881, -31.37456635,\n",
       "        -9.42198719, -71.38160642, -22.1749836 , -14.72987207,\n",
       "       -23.579862  , -34.49383473,  -1.17650006, -32.9082313 ,\n",
       "       -10.82270873, -18.26229298,  -8.29312136, -44.84196925,\n",
       "       -22.59249857, -61.73628479, -47.12976716, -65.62586183,\n",
       "       -33.36440654, -24.00481646, -29.33167521, -65.22705174,\n",
       "         1.43986417,  -4.5608828 , -25.25849628, -22.46484052,\n",
       "       -54.43073815, -16.81739829, -11.28766153, -35.2583945 ,\n",
       "        -5.57318217, -14.93090963, -70.95371454,  -6.50504563,\n",
       "        -1.22951329, -37.87549666, -23.68946477, -68.29965034,\n",
       "        14.9380316 , -62.55689349,  10.14793102, -24.44798387,\n",
       "       -32.85380751, -14.32957956, -85.68608523, -13.16399511,\n",
       "         9.27786354, -17.3272305 , -36.06508765, -17.04717415,\n",
       "       -19.71313066, -32.72639803,  -5.36345568,   7.68321079,\n",
       "         9.20404642,   5.76533756, -35.96351685, -13.02390594,\n",
       "       -54.87488091, -41.61765932,   5.93735933, -79.11922094,\n",
       "       -16.01401177, -19.72191261, -10.9633273 , -42.55204236,\n",
       "       -19.70961226, -16.20503049, -18.68732411, -17.94404259,\n",
       "        -7.17462625, -20.54727639, -16.81071326, -70.69029754,\n",
       "        -9.81778925, -32.87040451, -18.97773963, -21.37929792,\n",
       "       -25.15051838, -17.10985358, -13.5237181 , -23.76118846,\n",
       "        11.36506426, -14.50017847, -33.86304399, -13.71703356,\n",
       "       -50.52174464, -20.26633088, -56.12699982, -29.24273723,\n",
       "       -22.10083195, -31.39321864, -68.99340924, -60.34418931,\n",
       "        14.35287431,   8.69507573, -25.31397549,   2.38294889,\n",
       "         5.04572398, -19.5649317 , -59.19923489, -10.05790958,\n",
       "       -29.6621183 , -27.40194753,   6.13015332, -80.46965014,\n",
       "       -34.87541316, -49.8464651 , -36.0396559 , -48.50249557,\n",
       "       -19.96809526, -62.0577384 ,  -3.23795424, -25.32910744,\n",
       "       -65.14032159,  -9.42732343, -23.31748132,  19.38628419,\n",
       "       -18.84544623,  -4.4730859 , -13.77210705, -21.88092738,\n",
       "       -43.41392869, -51.85061708, -28.83916106, -13.90474758,\n",
       "        -2.51951282,  -6.16017921,   3.14867802, -15.33995088,\n",
       "       -41.16629369, -25.89746407,  14.30196855, -17.88818357,\n",
       "        14.67464831, -33.6579078 ,   4.82446939, -14.42659525,\n",
       "       -54.22944886, -50.49129203, -30.54686699, -38.72564866,\n",
       "       -23.46179517, -24.87719688, -14.50557245, -23.7245613 ,\n",
       "       -28.07007899, -19.6371693 , -28.66185683, -20.37693276,\n",
       "       -32.16748594, -11.15575219, -17.95925162, -24.54356192,\n",
       "       -24.60831864,  10.7369134 , -16.68578316, -38.50777749,\n",
       "       -15.87671704, -37.05244516, -15.79370448, -68.69484172,\n",
       "       -33.64811924, -43.60838042, -28.74753379,  -9.88987142,\n",
       "       -67.16452673, -33.49885792, -45.89914866, -14.36737092,\n",
       "       -38.28995583, -14.76246815, -70.44233406, -11.19630388,\n",
       "       -41.46525861, -32.38987154, -20.86071031, -27.68980551,\n",
       "       -16.06078355, -31.96315882,  -8.48422319, -22.10450644,\n",
       "       -34.06026605, -12.47053749, -36.15119732, -36.57961237,\n",
       "       -22.4615779 ,   4.47539707, -20.80768092,  -3.75031094,\n",
       "       -20.31646514, -32.67828751, -41.1070839 , -25.46019469,\n",
       "       -19.73667158, -47.83298668, -29.85783394, -45.24586627,\n",
       "       -71.65702933,  -5.93562653, -32.93703788,   1.8966422 ,\n",
       "        11.7638719 ,   7.35781728, -30.93185843, -63.94239912,\n",
       "       -23.4143365 ,  -5.43422595, -33.46410579, -24.11267351,\n",
       "       -67.49717965, -34.30056213, -34.23321455, -31.61587241,\n",
       "       -52.86794264, -22.89221089,  -8.16020811, -17.73974192,\n",
       "       -26.98681515, -32.38765342, -28.96086021, -67.2518123 ,\n",
       "       -46.49548802, -16.1128538 ])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_reg.decision_function(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-22.05700117, -33.02940957, -16.21334087, -80.3791447 ,\n",
       "       -48.25125396, -24.54005629, -44.39168773, -25.04292757,\n",
       "        -0.97829292, -19.7174399 ])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_reg.decision_function(X_test)[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_reg.predict(X_test)[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "decision_scores = log_reg.decision_function(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-85.68608522646575"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.min(decision_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19.8895858799022"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.max(decision_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict_2 = np.array(decision_scores >= 5, dtype='int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[404,   1],\n",
       "       [ 21,  24]], dtype=int64)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confusion_matrix(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.96"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_score(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5333333333333333"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall_score(y_test, y_predict_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predict_3 = np.array(decision_scores >= -5, dtype='int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[390,  15],\n",
       "       [  5,  40]], dtype=int64)"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confusion_matrix(y_test, y_predict_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7272727272727273"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_score(y_test, y_predict_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8888888888888888"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "recall_score(y_test, y_predict_3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
